{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Net.\n",
    "\n",
    "> Create and tune a PyTorch model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "# from nbdev.showdoc import *\n",
    "from fastcore.test import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "import torch.nn as nn\n",
    "import sys, torch\n",
    "from functools import partial\n",
    "from collections import OrderedDict\n",
    "\n",
    "from model_constructor.layers import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "# Default activation used as the `act_fn` default of the blocks defined below.\n",
    "act_fn = nn.ReLU(inplace=True)\n",
    "\n",
    "\n",
    "def init_cnn(m):\n",
    "    \"Recursively zero every bias and Kaiming-normal initialise Conv2d/Linear weights.\"\n",
    "    bias = getattr(m, 'bias', None)\n",
    "    if bias is not None:\n",
    "        nn.init.constant_(bias, 0)\n",
    "    # Only Conv2d and Linear weights are re-initialised; other layers keep their defaults.\n",
    "    if isinstance(m, (nn.Conv2d, nn.Linear)):\n",
    "        nn.init.kaiming_normal_(m.weight)\n",
    "    for child in m.children():\n",
    "        init_cnn(child)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ResBlock"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class ResBlock(nn.Module):\n",
    "    \"Residual block: a conv stack summed with a shortcut branch, followed by an activation.\"\n",
    "\n",
    "    def __init__(self, expansion, ni, nh, stride=1,\n",
    "                 conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,\n",
    "                 pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False):\n",
    "        super().__init__()\n",
    "        # Input and output widths are scaled by `expansion`; `nh` is the inner width.\n",
    "        nf, ni = nh * expansion, ni * expansion\n",
    "        if expansion == 1:\n",
    "            # Two kernel-size-3 convs; the first one carries the stride.\n",
    "            layers = [\n",
    "                (\"conv_0\", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st)),\n",
    "                (\"conv_1\", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st)),\n",
    "            ]\n",
    "        else:\n",
    "            # Kernel sizes 1 -> 3 -> 1; the middle conv carries the stride.\n",
    "            layers = [\n",
    "                (\"conv_0\", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),\n",
    "                (\"conv_1\", conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st)),\n",
    "                (\"conv_2\", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st)),\n",
    "            ]\n",
    "        if sa:\n",
    "            layers.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))\n",
    "        self.convs = nn.Sequential(OrderedDict(layers))\n",
    "        # Shortcut branch: pool only when the block is strided, project only when widths differ.\n",
    "        self.pool = pool if stride != 1 else noop\n",
    "        self.idconv = conv_layer(ni, nf, 1, act=False) if ni != nf else noop\n",
    "        self.act_fn = act_fn\n",
    "\n",
    "    def forward(self, x):\n",
    "        residual = self.convs(x)\n",
    "        shortcut = self.idconv(self.pool(x))\n",
    "        return self.act_fn(residual + shortcut)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResBlock(\n",
       "  (convs): Sequential(\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (conv_1): ConvLayer(\n",
       "      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (sa): SimpleSelfAttention(\n",
       "      (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)\n",
       "    )\n",
       "  )\n",
       "  (act_fn): ReLU(inplace=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ResBlock(1,64,64,sa=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResBlock(\n",
       "  (convs): Sequential(\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (act_fn): LeakyReLU(negative_slope=0.01)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv_1): ConvLayer(\n",
       "      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (act_fn): LeakyReLU(negative_slope=0.01)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv_2): ConvLayer(\n",
       "      (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (act_fn): LeakyReLU(negative_slope=0.01)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ResBlock(2,64,64,act_fn=nn.LeakyReLU(), bn_1st=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NewResBlock"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "# This block has no proper name yet, so for now it is simply called NewResBlock.\n",
    "class NewResBlock(nn.Module):\n",
    "    \"Residual block variant: the input is reduced first, then fed to both the conv stack and the shortcut.\"\n",
    "\n",
    "    def __init__(self, expansion, ni, nh, stride=1,\n",
    "                 conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,\n",
    "                 pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False):\n",
    "        super().__init__()\n",
    "        nf, ni = nh * expansion, ni * expansion\n",
    "        # Downsampling is done by `pool` ahead of both branches, so every conv below uses stride 1.\n",
    "        self.reduce = pool if stride != 1 else noop\n",
    "        if expansion == 1:\n",
    "            layers = [\n",
    "                (\"conv_0\", conv_layer(ni, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st)),\n",
    "                (\"conv_1\", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st)),\n",
    "            ]\n",
    "        else:\n",
    "            layers = [\n",
    "                (\"conv_0\", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),\n",
    "                (\"conv_1\", conv_layer(nh, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st)),\n",
    "                (\"conv_2\", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st)),\n",
    "            ]\n",
    "        if sa:\n",
    "            layers.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))\n",
    "        self.convs = nn.Sequential(OrderedDict(layers))\n",
    "        # Project the shortcut only when input and output widths differ.\n",
    "        self.idconv = conv_layer(ni, nf, 1, act=False) if ni != nf else noop\n",
    "        self.merge = act_fn\n",
    "\n",
    "    def forward(self, x):\n",
    "        reduced = self.reduce(x)\n",
    "        residual = self.convs(reduced)\n",
    "        shortcut = self.idconv(reduced)\n",
    "        return self.merge(residual + shortcut)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NewResBlock(\n",
       "  (convs): Sequential(\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (conv_1): ConvLayer(\n",
       "      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (sa): SimpleSelfAttention(\n",
       "      (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)\n",
       "    )\n",
       "  )\n",
       "  (merge): ReLU(inplace=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "bl = NewResBlock(1,64,64,sa=True)\n",
    "bl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 64, 32, 32])\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "bs_test = 16\n",
    "xb = torch.randn(bs_test, 64, 32, 32)\n",
    "y = bl(xb)\n",
    "print(y.shape)\n",
    "# `bl` was built with the default stride of 1, so channels and spatial size must be unchanged.\n",
    "expected = torch.Size([bs_test, 64, 32, 32])\n",
    "assert y.shape == expected, f\"expected {expected}, got {y.shape}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NewResBlock(\n",
       "  (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "  (convs): Sequential(\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (act_fn): LeakyReLU(negative_slope=0.01)\n",
       "      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv_1): ConvLayer(\n",
       "      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (act_fn): LeakyReLU(negative_slope=0.01)\n",
       "      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv_2): ConvLayer(\n",
       "      (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (idconv): ConvLayer(\n",
       "    (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (merge): LeakyReLU(negative_slope=0.01)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "bl = NewResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False)\n",
    "bl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 512, 16, 16])\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "bs_test = 16\n",
    "xb = torch.randn(bs_test, 256, 32, 32)\n",
    "y = bl(xb)\n",
    "print(y.shape)\n",
    "# stride=2 halves the spatial size; expansion 4 turns nh=128 into 512 output channels.\n",
    "expected = torch.Size([bs_test, 512, 16, 16])\n",
    "assert y.shape == expected, f\"expected {expected}, got {y.shape}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stem, Body, Head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def _make_stem(self):\n",
    "    \"Build the stem: a chain of conv layers, then `stem_pool`, then an optional trailing norm.\"\n",
    "    sizes = self.stem_sizes\n",
    "    last_conv = len(sizes) - 2\n",
    "    stem = []\n",
    "    for i in range(len(sizes) - 1):\n",
    "        # When the stem ends with its own norm layer, the last conv is built without bn.\n",
    "        use_bn = (not self.stem_bn_end) if i == last_conv else True\n",
    "        conv = self.conv_layer(sizes[i], sizes[i+1],\n",
    "                               stride=2 if i == 0 else 1,\n",
    "                               bn_layer=use_bn,\n",
    "                               act_fn=self.act_fn, bn_1st=self.bn_1st)\n",
    "        stem.append((f\"conv_{i}\", conv))\n",
    "    stem.append(('stem_pool', self.stem_pool))\n",
    "    if self.stem_bn_end:\n",
    "        stem.append(('norm', self.norm(sizes[-1])))\n",
    "    return nn.Sequential(OrderedDict(stem))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def _make_layer(self,expansion,ni,nf,blocks,stride,sa):\n",
    "    \"Stack `blocks` blocks: only the first gets `ni` and `stride`, only the last gets `sa`.\"\n",
    "    layer = OrderedDict()\n",
    "    for i in range(blocks):\n",
    "        is_first, is_last = i == 0, i == blocks - 1\n",
    "        layer[f\"bl_{i}\"] = self.block(\n",
    "            expansion, ni if is_first else nf, nf,\n",
    "            stride if is_first else 1,\n",
    "            sa=sa if is_last else False,\n",
    "            conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,\n",
    "            zero_bn=self.zero_bn, bn_1st=self.bn_1st)\n",
    "    return nn.Sequential(layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def _make_body(self):\n",
    "    \"Build the body: one group of blocks per entry of `self.layers`.\"\n",
    "    body = OrderedDict()\n",
    "    for i, n_blocks in enumerate(self.layers):\n",
    "        is_first = i == 0\n",
    "        # `_make_layer` is looked up on `self` and receives `self` explicitly as first argument.\n",
    "        body[f\"l_{i}\"] = self._make_layer(\n",
    "            self, self.expansion,\n",
    "            self.block_szs[i], self.block_szs[i+1], n_blocks,\n",
    "            1 if is_first else 2,\n",
    "            self.sa if is_first else False)\n",
    "    return nn.Sequential(body)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def _make_head(self):\n",
    "    \"Build the head: adaptive average pool to 1x1, flatten, then a linear layer with `c_out` outputs.\"\n",
    "    n_features = self.block_szs[-1] * self.expansion\n",
    "    head = OrderedDict()\n",
    "    head['pool'] = nn.AdaptiveAvgPool2d(1)\n",
    "    head['flat'] = Flatten()\n",
    "    head['fc'] = nn.Linear(n_features, self.c_out)\n",
    "    return nn.Sequential(head)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Net class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "# v8\n",
    "class Net():\n",
    "    \"\"\"Configurable constructor that builds a ResNet-style `nn.Sequential` model.\n",
    "\n",
    "    Settings are plain attributes (`block`, `act_fn`, `sa`, `expansion`, `stem_sizes`, ...)\n",
    "    that can be changed after construction. `stem`, `body` and `head` build a new module\n",
    "    from the current settings on every access; calling the instance assembles all three\n",
    "    into one model and applies `_init_cnn` to it.\n",
    "\n",
    "    Args:\n",
    "        expansion: channel multiplier passed to the blocks and used to size the head.\n",
    "        layers: number of blocks in each body group (copied, never stored by reference).\n",
    "        c_in: number of input channels of the stem.\n",
    "        c_out: number of outputs of the final linear layer.\n",
    "        name: label used in the repr of the constructor and of the built model.\n",
    "    \"\"\"\n",
    "    def __init__(self, expansion=1, layers=[2,2,2,2], c_in=3, c_out=1000, name='Net'):\n",
    "        super().__init__()\n",
    "        self.name = name\n",
    "        self.c_in, self.c_out, self.expansion = c_in, c_out, expansion # todo setter for expansion\n",
    "        # Copy so that instances never alias the shared default list (or the caller's list).\n",
    "        self.layers = list(layers)\n",
    "        self.stem_sizes = [c_in,32,32,64]\n",
    "        self.stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    "        self.stem_bn_end = False\n",
    "        self.block = ResBlock\n",
    "        self.norm = nn.BatchNorm2d\n",
    "        self.act_fn=nn.ReLU(inplace=True)\n",
    "        self.pool = nn.AvgPool2d(2, ceil_mode=True)\n",
    "        self.sa=False\n",
    "        self.bn_1st = True\n",
    "        self.zero_bn=True\n",
    "        self.conv_layer = ConvLayer\n",
    "        # The builders are stored as plain functions on the instance, so they can be swapped;\n",
    "        # they are therefore always called with `self` passed explicitly.\n",
    "        self._init_cnn = init_cnn\n",
    "        self._make_stem = _make_stem\n",
    "        self._make_layer = _make_layer\n",
    "        self._make_body = _make_body\n",
    "        self._make_head = _make_head\n",
    "\n",
    "    @property\n",
    "    def block_szs(self):\n",
    "        \"Channel sizes between body groups; groups beyond the fourth use 256.\"\n",
    "        return [64//self.expansion,64,128,256,512] +[256]*(len(self.layers)-4)\n",
    "\n",
    "    @property\n",
    "    def stem(self):\n",
    "        \"A newly built stem for the current settings.\"\n",
    "        return self._make_stem(self)\n",
    "\n",
    "    @property\n",
    "    def head(self):\n",
    "        \"A newly built head for the current settings.\"\n",
    "        return self._make_head(self)\n",
    "\n",
    "    @property\n",
    "    def body(self):\n",
    "        \"A newly built body for the current settings.\"\n",
    "        return self._make_body(self)\n",
    "\n",
    "    def __call__(self):\n",
    "        \"Assemble stem, body and head into one model, initialise it and return it.\"\n",
    "        model = nn.Sequential(OrderedDict([\n",
    "            ('stem', self.stem),\n",
    "            ('body', self.body),\n",
    "            ('head', self.head)\n",
    "        ]))\n",
    "        self._init_cnn(model)\n",
    "        model.extra_repr = lambda : f\"model {self.name}\"\n",
    "        return model\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\" constr {self.name}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model  = Net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       " constr Net"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[64, 64, 128, 256, 512]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.block_szs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (conv_0): ConvLayer(\n",
       "    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (conv_1): ConvLayer(\n",
       "    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (conv_2): ConvLayer(\n",
       "    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.stem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 64, 32, 32])\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "bs_test = 16\n",
    "xb = torch.randn(bs_test, 3, 128, 128)\n",
    "y = model.stem(xb)\n",
    "print(y.shape)\n",
    "# The strided first conv and the stem pool each halve the resolution: 128 -> 32.\n",
    "expected = torch.Size([bs_test, 64, 32, 32])\n",
    "assert y.shape == expected, f\"expected {expected}, got {y.shape}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.bn_1st = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.act_fn =nn.LeakyReLU(inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.sa = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (bl_0): ResBlock(\n",
       "    (convs): Sequential(\n",
       "      (conv_0): ConvLayer(\n",
       "        (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv_1): ConvLayer(\n",
       "        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "    (idconv): ConvLayer(\n",
       "      (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "  )\n",
       "  (bl_1): ResBlock(\n",
       "    (convs): Sequential(\n",
       "      (conv_0): ConvLayer(\n",
       "        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (conv_1): ConvLayer(\n",
       "        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.body.l_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 64, 32, 32])\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "bs_test = 16\n",
    "xb = torch.randn(bs_test, 64, 32, 32)\n",
    "y = model.body.l_0(xb)\n",
    "print(y.shape)\n",
    "# The first body group is built with stride 1, so the shape must be preserved.\n",
    "expected = torch.Size([bs_test, 64, 32, 32])\n",
    "assert y.shape == expected, f\"expected {expected}, got {y.shape}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.block = NewResBlock"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 128, 16, 16])\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "bs_test = 16\n",
    "xb = torch.randn(bs_test, 64, 32, 32)\n",
    "y = model.body.l_1.bl_0(xb)\n",
    "print(y.shape)\n",
    "# The first block of the second group is strided: 64 -> 128 channels, 32 -> 16 pixels.\n",
    "expected = torch.Size([bs_test, 128, 16, 16])\n",
    "assert y.shape == expected, f\"expected {expected}, got {y.shape}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NewResBlock(\n",
       "  (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "  (convs): Sequential(\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv_1): ConvLayer(\n",
       "      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (idconv): ConvLayer(\n",
       "    (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (merge): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "model.body.l_1.bl_0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model = Net(expansion=4)\n",
    "model.expansion = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.stem_bn_end = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (conv_0): ConvLayer(\n",
       "    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (conv_1): ConvLayer(\n",
       "    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (conv_2): ConvLayer(\n",
       "    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "  )\n",
       "  (stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.stem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NewResBlock(\n",
       "  (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "  (convs): Sequential(\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv_1): ConvLayer(\n",
       "      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv_2): ConvLayer(\n",
       "      (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "  )\n",
       "  (idconv): ConvLayer(\n",
       "    (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (merge): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "model.body.l_1.bl_0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 512, 16, 16])\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "bs_test = 16\n",
    "xb = torch.randn(bs_test, 256, 32, 32)\n",
    "y = model.body.l_1.bl_0(xb)\n",
    "print(y.shape)\n",
    "# With expansion 4 the strided block maps 256 -> 512 channels and 32 -> 16 pixels.\n",
    "expected = torch.Size([bs_test, 512, 16, 16])\n",
    "assert y.shape == expected, f\"expected {expected}, got {y.shape}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  model Net\n",
       "  (stem): Sequential(\n",
       "    (conv_0): ConvLayer(\n",
       "      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv_1): ConvLayer(\n",
       "      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv_2): ConvLayer(\n",
       "      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "    )\n",
       "    (stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "    (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (body): Sequential(\n",
       "    (l_0): Sequential(\n",
       "      (bl_0): NewResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      )\n",
       "      (bl_1): NewResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (sa): SimpleSelfAttention(\n",
       "            (conv): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)\n",
       "          )\n",
       "        )\n",
       "        (merge): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (l_1): Sequential(\n",
       "      (bl_0): NewResBlock(\n",
       "        (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      )\n",
       "      (bl_1): NewResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (l_2): Sequential(\n",
       "      (bl_0): NewResBlock(\n",
       "        (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      )\n",
       "      (bl_1): NewResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (l_3): Sequential(\n",
       "      (bl_0): NewResBlock(\n",
       "        (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      )\n",
       "      (bl_1): NewResBlock(\n",
       "        (convs): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (head): Sequential(\n",
       "    (pool): AdaptiveAvgPool2d(output_size=1)\n",
       "    (flat): Flatten()\n",
       "    (fc): Linear(in_features=2048, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 1000])\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "bs_test = 16\n",
    "xb = torch.randn(bs_test, 3, 128, 128)\n",
    "y = m(xb)\n",
    "print(y.shape)\n",
    "assert y.shape == torch.Size([bs_test, 1000]), f\"size expected {bs_test}, 1000\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.stem_sizes = [3, 32, 32, 64, 64]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (conv_0): ConvLayer(\n",
       "    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (conv_1): ConvLayer(\n",
       "    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (conv_2): ConvLayer(\n",
       "    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (conv_3): ConvLayer(\n",
       "    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "  )\n",
       "  (stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "model.stem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (pool): AdaptiveAvgPool2d(output_size=1)\n",
       "  (flat): Flatten()\n",
       "  (fc): Linear(in_features=2048, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide\n",
    "model.head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## xresnet constructor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "me = sys.modules[__name__]\n",
    "for n,e,l in [[ 18 , 1, [2,2,2 ,2] ],\n",
    "    [ 34 , 1, [3,4,6 ,3] ],\n",
    "    [ 50 , 4, [3,4,6 ,3] ],\n",
    "    [ 101, 4, [3,4,23,3] ],\n",
    "    [ 152, 4, [3,8,36,3] ],]:\n",
    "    name = f'net{n}'\n",
    "    setattr(me, name, partial(Net, expansion=e, layers=l, name=name))\n",
    "xresnet34  = partial(Net, expansion=1, layers=[3, 4,  6, 3], name='xresnet34')\n",
    "xresnet50  = partial(Net, expansion=4, layers=[3, 4,  6, 3], name='xresnet50')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = xresnet50(c_out=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "( constr xresnet50, 10)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m, m.c_out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# end\n",
    "model_constructor\n",
    "by ayasyrev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_constructor.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_resnet.ipynb.\n",
      "Converted 03_xresnet.ipynb.\n",
      "Converted 04_Net.ipynb.\n",
      "Converted 80_test_net.ipynb.\n",
      "Converted 81_Net.ipynb.\n",
      "Converted 81_test_xresnet.ipynb.\n",
      "Converted index.ipynb.\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "from nbdev.export import *\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
